{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "977a9bb6",
   "metadata": {},
   "source": [
    "# Tutorial 02: Navier-Stokes problem\n",
    "\n",
    "In this tutorial we compare the formulation of a Navier-Stokes problem using standard assembly with mixed function spaces and block assembly. This tutorial serves as a reminder on how to solve a nonlinear problem in `dolfinx` interfacing to `SNES` (part of `PETSc`), and shows a further usage of `multiphenicsx.fem.petsc.BlockVecSubVectorWrapper`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bef0cb78",
   "metadata": {},
   "outputs": [],
   "source": [
    "import typing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b48b8449",
   "metadata": {},
   "outputs": [],
   "source": [
    "import basix.ufl\n",
    "import dolfinx.fem\n",
    "import dolfinx.fem.petsc\n",
    "import dolfinx.io\n",
    "import gmsh\n",
    "import mpi4py.MPI\n",
    "import numpy as np\n",
    "import numpy.typing\n",
    "import petsc4py.PETSc\n",
    "import ufl\n",
    "import viskex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30328ec0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import multiphenicsx.fem\n",
    "import multiphenicsx.fem.petsc"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34a7fce6",
   "metadata": {},
   "source": [
    "### Constitutive parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45be824d",
   "metadata": {},
   "outputs": [],
   "source": [
    "nu = 0.01\n",
    "\n",
    "\n",
    "def u_in_eval(x: np.typing.NDArray[np.float64]) -> np.typing.NDArray[  # type: ignore[no-any-unimported]\n",
    "        petsc4py.PETSc.ScalarType]:\n",
    "    \"\"\"Return the flat velocity profile at the inlet.\"\"\"\n",
    "    values = np.zeros((2, x.shape[1]))\n",
    "    values[0, :] = 1.0\n",
    "    return values\n",
    "\n",
    "\n",
    "def u_wall_eval(x: np.typing.NDArray[np.float64]) -> np.typing.NDArray[  # type: ignore[no-any-unimported]\n",
    "        petsc4py.PETSc.ScalarType]:\n",
    "    \"\"\"Return the zero velocity at the wall.\"\"\"\n",
    "    return np.zeros((2, x.shape[1]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31c13354",
   "metadata": {},
   "source": [
    "### Geometrical parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc81d45b",
   "metadata": {},
   "outputs": [],
   "source": [
    "pre_step_length = 4.\n",
    "after_step_length = 14.\n",
    "pre_step_height = 3.\n",
    "after_step_height = 5.\n",
    "mesh_size = 1. / 5."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f68222b",
   "metadata": {},
   "source": [
    "### Mesh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6907c721",
   "metadata": {},
   "outputs": [],
   "source": [
    "gmsh.initialize()\n",
    "gmsh.model.add(\"mesh\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a299d77",
   "metadata": {},
   "outputs": [],
   "source": [
    "p0 = gmsh.model.geo.addPoint(0.0, after_step_height - pre_step_height, 0.0, mesh_size)\n",
    "p1 = gmsh.model.geo.addPoint(pre_step_length, after_step_height - pre_step_height, 0.0, mesh_size)\n",
    "p2 = gmsh.model.geo.addPoint(pre_step_length, 0.0, 0.0, mesh_size)\n",
    "p3 = gmsh.model.geo.addPoint(pre_step_length + after_step_length, 0.0, 0.0, mesh_size)\n",
    "p4 = gmsh.model.geo.addPoint(pre_step_length + after_step_length, after_step_height, 0.0, mesh_size)\n",
    "p5 = gmsh.model.geo.addPoint(0.0, after_step_height, 0.0, mesh_size)\n",
    "l0 = gmsh.model.geo.addLine(p0, p1)\n",
    "l1 = gmsh.model.geo.addLine(p1, p2)\n",
    "l2 = gmsh.model.geo.addLine(p2, p3)\n",
    "l3 = gmsh.model.geo.addLine(p3, p4)\n",
    "l4 = gmsh.model.geo.addLine(p4, p5)\n",
    "l5 = gmsh.model.geo.addLine(p5, p0)\n",
    "line_loop = gmsh.model.geo.addCurveLoop([l0, l1, l2, l3, l4, l5])\n",
    "domain = gmsh.model.geo.addPlaneSurface([line_loop])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd22b2a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "gmsh.model.geo.synchronize()\n",
    "gmsh.model.addPhysicalGroup(1, [l5], 1)\n",
    "gmsh.model.addPhysicalGroup(1, [l0, l1, l2, l4], 2)\n",
    "gmsh.model.addPhysicalGroup(2, [domain], 0)\n",
    "gmsh.model.mesh.generate(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaa2e68b",
   "metadata": {},
   "outputs": [],
   "source": [
    "mesh, subdomains, boundaries = dolfinx.io.gmshio.model_to_mesh(\n",
    "    gmsh.model, comm=mpi4py.MPI.COMM_WORLD, rank=0, gdim=2)\n",
    "gmsh.finalize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5aa0382-47b5-4dcf-a1e0-a1b32699245a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create connectivities required by the rest of the code\n",
    "mesh.topology.create_connectivity(mesh.topology.dim - 1, mesh.topology.dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9c3fed0",
   "metadata": {},
   "outputs": [],
   "source": [
    "boundaries_1 = boundaries.indices[boundaries.values == 1]\n",
    "boundaries_2 = boundaries.indices[boundaries.values == 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bb0107e",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_mesh(mesh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06aa4df5",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_mesh_tags(mesh, boundaries, \"boundaries\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da33800f",
   "metadata": {},
   "source": [
    "### Function spaces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d33248cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "V_element = basix.ufl.element(\"Lagrange\", mesh.basix_cell(), 2, shape=(mesh.geometry.dim, ))\n",
    "Q_element = basix.ufl.element(\"Lagrange\", mesh.basix_cell(), 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "815996fc",
   "metadata": {},
   "source": [
    "### Standard formulation using a mixed function space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51b831e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_monolithic() -> dolfinx.fem.Function:\n",
    "    \"\"\"Run standard formulation using a mixed function space.\"\"\"\n",
    "    # Function spaces\n",
    "    W_element = basix.ufl.mixed_element([V_element, Q_element])\n",
    "    W = dolfinx.fem.functionspace(mesh, W_element)\n",
    "\n",
    "    # Test and trial functions: monolithic\n",
    "    vq = ufl.TestFunction(W)\n",
    "    (v, q) = ufl.split(vq)\n",
    "    dup = ufl.TrialFunction(W)\n",
    "    up = dolfinx.fem.Function(W)\n",
    "    (u, p) = ufl.split(up)\n",
    "\n",
    "    # Variational forms\n",
    "    F = (nu * ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx\n",
    "         + ufl.inner(ufl.grad(u) * u, v) * ufl.dx\n",
    "         - ufl.inner(p, ufl.div(v)) * ufl.dx\n",
    "         + ufl.inner(ufl.div(u), q) * ufl.dx)\n",
    "    J = ufl.derivative(F, up, dup)\n",
    "\n",
    "    # Boundary conditions\n",
    "    W0 = W.sub(0)\n",
    "    V, _ = W0.collapse()\n",
    "    u_in = dolfinx.fem.Function(V)\n",
    "    u_in.interpolate(u_in_eval)\n",
    "    u_wall = dolfinx.fem.Function(V)\n",
    "    u_wall.interpolate(u_wall_eval)\n",
    "    bdofs_V_1 = dolfinx.fem.locate_dofs_topological((W0, V), mesh.topology.dim - 1, boundaries_1)\n",
    "    bdofs_V_2 = dolfinx.fem.locate_dofs_topological((W0, V), mesh.topology.dim - 1, boundaries_2)\n",
    "    inlet_bc = dolfinx.fem.dirichletbc(u_in, bdofs_V_1, W0)\n",
    "    wall_bc = dolfinx.fem.dirichletbc(u_wall, bdofs_V_2, W0)\n",
    "    bc = [inlet_bc, wall_bc]\n",
    "\n",
    "    # Class for interfacing with SNES\n",
    "    class NavierStokesProblem:\n",
    "        \"\"\"Define a nonlinear problem, interfacing with SNES.\"\"\"\n",
    "\n",
    "        def __init__(  # type: ignore[no-any-unimported]\n",
    "            self, F: ufl.Form, J: ufl.Form, solution: dolfinx.fem.Function,\n",
    "            bcs: list[dolfinx.fem.DirichletBC], P: typing.Optional[ufl.Form] = None\n",
    "        ) -> None:\n",
    "            self._F = dolfinx.fem.form(F)\n",
    "            self._J = dolfinx.fem.form(J)\n",
    "            self._obj_vec = dolfinx.fem.petsc.create_vector(self._F)\n",
    "            self._solution = solution\n",
    "            self._bcs = bcs\n",
    "            self._P = P\n",
    "\n",
    "        def create_snes_solution(self) -> petsc4py.PETSc.Vec:  # type: ignore[no-any-unimported]\n",
    "            \"\"\"\n",
    "            Create a petsc4py.PETSc.Vec to be passed to petsc4py.PETSc.SNES.solve.\n",
    "\n",
    "            The returned vector will be initialized with the initial guess provided in `self._solution`.\n",
    "            \"\"\"\n",
    "            x = self._solution.x.petsc_vec.copy()\n",
    "            with x.localForm() as _x, self._solution.x.petsc_vec.localForm() as _solution:\n",
    "                _x[:] = _solution\n",
    "            return x\n",
    "\n",
    "        def update_solution(self, x: petsc4py.PETSc.Vec) -> None:  # type: ignore[no-any-unimported]\n",
    "            \"\"\"Update `self._solution` with data in `x`.\"\"\"\n",
    "            x.ghostUpdate(addv=petsc4py.PETSc.InsertMode.INSERT, mode=petsc4py.PETSc.ScatterMode.FORWARD)\n",
    "            with x.localForm() as _x, self._solution.x.petsc_vec.localForm() as _solution:\n",
    "                _solution[:] = _x\n",
    "\n",
    "        def obj(  # type: ignore[no-any-unimported]\n",
    "            self, snes: petsc4py.PETSc.SNES, x: petsc4py.PETSc.Vec\n",
    "        ) -> np.float64:\n",
    "            \"\"\"Compute the norm of the residual.\"\"\"\n",
    "            self.F(snes, x, self._obj_vec)\n",
    "            return self._obj_vec.norm()  # type: ignore[no-any-return]\n",
    "\n",
    "        def F(  # type: ignore[no-any-unimported]\n",
    "            self, snes: petsc4py.PETSc.SNES, x: petsc4py.PETSc.Vec, F_vec: petsc4py.PETSc.Vec\n",
    "        ) -> None:\n",
    "            \"\"\"Assemble the residual.\"\"\"\n",
    "            self.update_solution(x)\n",
    "            with F_vec.localForm() as F_vec_local:\n",
    "                F_vec_local.set(0.0)\n",
    "            dolfinx.fem.petsc.assemble_vector(F_vec, self._F)\n",
    "            dolfinx.fem.petsc.apply_lifting(F_vec, [self._J], [self._bcs], x0=[x], alpha=-1.0)\n",
    "            F_vec.ghostUpdate(addv=petsc4py.PETSc.InsertMode.ADD, mode=petsc4py.PETSc.ScatterMode.REVERSE)\n",
    "            dolfinx.fem.petsc.set_bc(F_vec, self._bcs, x, -1.0)\n",
    "\n",
    "        def J(  # type: ignore[no-any-unimported]\n",
    "            self, snes: petsc4py.PETSc.SNES, x: petsc4py.PETSc.Vec, J_mat: petsc4py.PETSc.Mat,\n",
    "            P_mat: petsc4py.PETSc.Mat\n",
    "        ) -> None:\n",
    "            \"\"\"Assemble the jacobian.\"\"\"\n",
    "            J_mat.zeroEntries()\n",
    "            dolfinx.fem.petsc.assemble_matrix(  # type: ignore[misc]\n",
    "                J_mat, self._J, self._bcs, diagonal=1.0)  # type: ignore[arg-type]\n",
    "            J_mat.assemble()\n",
    "            if self._P is not None:\n",
    "                P_mat.zeroEntries()\n",
    "                dolfinx.fem.petsc.assemble_matrix(  # type: ignore[misc]\n",
    "                    P_mat, self._P, self._bcs, diagonal=1.0)  # type: ignore[arg-type]\n",
    "                P_mat.assemble()\n",
    "\n",
    "    # Create problem\n",
    "    problem = NavierStokesProblem(F, J, up, bc)\n",
    "    F_vec = dolfinx.fem.petsc.create_vector(problem._F)\n",
    "    J_mat = dolfinx.fem.petsc.create_matrix(problem._J)\n",
    "\n",
    "    # Solve\n",
    "    snes = petsc4py.PETSc.SNES().create(mesh.comm)\n",
    "    snes.setTolerances(max_it=20)\n",
    "    snes.getKSP().setType(\"preonly\")\n",
    "    snes.getKSP().getPC().setType(\"lu\")\n",
    "    snes.getKSP().getPC().setFactorSolverType(\"mumps\")\n",
    "    snes.setObjective(problem.obj)\n",
    "    snes.setFunction(problem.F, F_vec)\n",
    "    snes.setJacobian(problem.J, J=J_mat, P=None)\n",
    "    snes.setMonitor(lambda _, it, residual: print(it, residual))\n",
    "    up_copy = problem.create_snes_solution()\n",
    "    snes.solve(None, up_copy)\n",
    "    problem.update_solution(up_copy)  # TODO can this be safely removed?\n",
    "    up_copy.destroy()\n",
    "    snes.destroy()\n",
    "    return up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "913136b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "up_m = run_monolithic()\n",
    "(u_m, p_m) = (up_m.sub(0).collapse(), up_m.sub(1).collapse())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfddc182",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_vector_field(u_m, \"u\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3578582",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_vector_field(u_m, \"u\", glyph_factor=1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f32353ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_scalar_field(p_m, \"p\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc2efeb8",
   "metadata": {},
   "source": [
    "### Block formulation using two independent function spaces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9a4fe94",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_block() -> tuple[dolfinx.fem.Function, dolfinx.fem.Function]:\n",
    "    \"\"\"Run block formulation using two independent function spaces.\"\"\"\n",
    "    # Function spaces\n",
    "    V = dolfinx.fem.functionspace(mesh, V_element)\n",
    "    Q = dolfinx.fem.functionspace(mesh, Q_element)\n",
    "\n",
    "    # Test and trial functions\n",
    "    (v, q) = (ufl.TestFunction(V), ufl.TestFunction(Q))\n",
    "    (du, dp) = (ufl.TrialFunction(V), ufl.TrialFunction(Q))\n",
    "    (u, p) = (dolfinx.fem.Function(V), dolfinx.fem.Function(Q))\n",
    "\n",
    "    # Variational forms\n",
    "    F = [(nu * ufl.inner(ufl.grad(u), ufl.grad(v)) * ufl.dx + ufl.inner(ufl.grad(u) * u, v) * ufl.dx\n",
    "          - ufl.inner(p, ufl.div(v)) * ufl.dx),\n",
    "         ufl.inner(ufl.div(u), q) * ufl.dx]\n",
    "    J = [[ufl.derivative(F[0], u, du), ufl.derivative(F[0], p, dp)],\n",
    "         [ufl.derivative(F[1], u, du), ufl.derivative(F[1], p, dp)]]\n",
    "\n",
    "    # Boundary conditions\n",
    "    u_in = dolfinx.fem.Function(V)\n",
    "    u_in.interpolate(u_in_eval)\n",
    "    u_wall = dolfinx.fem.Function(V)\n",
    "    u_wall.interpolate(u_wall_eval)\n",
    "    bdofs_V_1 = dolfinx.fem.locate_dofs_topological(V, mesh.topology.dim - 1, boundaries_1)\n",
    "    bdofs_V_2 = dolfinx.fem.locate_dofs_topological(V, mesh.topology.dim - 1, boundaries_2)\n",
    "    inlet_bc = dolfinx.fem.dirichletbc(u_in, bdofs_V_1)\n",
    "    wall_bc = dolfinx.fem.dirichletbc(u_wall, bdofs_V_2)\n",
    "    bc = [inlet_bc, wall_bc]\n",
    "\n",
    "    # Class for interfacing with SNES\n",
    "    class NavierStokesProblem:\n",
    "        \"\"\"Define a nonlinear problem, interfacing with SNES.\"\"\"\n",
    "\n",
    "        def __init__(  # type: ignore[no-any-unimported]\n",
    "            self, F: list[ufl.Form], J: list[list[ufl.Form]],\n",
    "            solutions: tuple[dolfinx.fem.Function, dolfinx.fem.Function],\n",
    "            bcs: list[dolfinx.fem.DirichletBC],\n",
    "            P: typing.Optional[list[list[ufl.Form]]] = None\n",
    "        ) -> None:\n",
    "            self._F = dolfinx.fem.form(F)\n",
    "            self._J = dolfinx.fem.form(J)\n",
    "            self._obj_vec = dolfinx.fem.petsc.create_vector_block(self._F)\n",
    "            self._solutions = solutions\n",
    "            self._bcs = bcs\n",
    "            self._P = P\n",
    "\n",
    "        def create_snes_solution(self) -> petsc4py.PETSc.Vec:  # type: ignore[no-any-unimported]\n",
    "            \"\"\"\n",
    "            Create a petsc4py.PETSc.Vec to be passed to petsc4py.PETSc.SNES.solve.\n",
    "\n",
    "            The returned vector will be initialized with the initial guesses provided in `self._solutions`,\n",
    "            properly stacked together in a single block vector.\n",
    "            \"\"\"\n",
    "            x = dolfinx.fem.petsc.create_vector_block(self._F)\n",
    "            with multiphenicsx.fem.petsc.BlockVecSubVectorWrapper(x, [V.dofmap, Q.dofmap]) as x_wrapper:\n",
    "                for x_wrapper_local, sub_solution in zip(x_wrapper, self._solutions):\n",
    "                    with sub_solution.x.petsc_vec.localForm() as sub_solution_local:\n",
    "                        x_wrapper_local[:] = sub_solution_local\n",
    "            return x\n",
    "\n",
    "        def update_solutions(self, x: petsc4py.PETSc.Vec) -> None:  # type: ignore[no-any-unimported]\n",
    "            \"\"\"Update `self._solutions` with data in `x`.\"\"\"\n",
    "            x.ghostUpdate(addv=petsc4py.PETSc.InsertMode.INSERT, mode=petsc4py.PETSc.ScatterMode.FORWARD)\n",
    "            with multiphenicsx.fem.petsc.BlockVecSubVectorWrapper(x, [V.dofmap, Q.dofmap]) as x_wrapper:\n",
    "                for x_wrapper_local, sub_solution in zip(x_wrapper, self._solutions):\n",
    "                    with sub_solution.x.petsc_vec.localForm() as sub_solution_local:\n",
    "                        sub_solution_local[:] = x_wrapper_local\n",
    "\n",
    "        def obj(  # type: ignore[no-any-unimported]\n",
    "            self, snes: petsc4py.PETSc.SNES, x: petsc4py.PETSc.Vec\n",
    "        ) -> np.float64:\n",
    "            \"\"\"Compute the norm of the residual.\"\"\"\n",
    "            self.F(snes, x, self._obj_vec)\n",
    "            return self._obj_vec.norm()  # type: ignore[no-any-return]\n",
    "\n",
    "        def F(  # type: ignore[no-any-unimported]\n",
    "            self, snes: petsc4py.PETSc.SNES, x: petsc4py.PETSc.Vec, F_vec: petsc4py.PETSc.Vec\n",
    "        ) -> None:\n",
    "            \"\"\"Assemble the residual.\"\"\"\n",
    "            self.update_solutions(x)\n",
    "            with F_vec.localForm() as F_vec_local:\n",
    "                F_vec_local.set(0.0)\n",
    "            dolfinx.fem.petsc.assemble_vector_block(  # type: ignore[misc]\n",
    "                F_vec, self._F, self._J, self._bcs, x0=x, alpha=-1.0)\n",
    "\n",
    "        def J(  # type: ignore[no-any-unimported]\n",
    "            self, snes: petsc4py.PETSc.SNES, x: petsc4py.PETSc.Vec, J_mat: petsc4py.PETSc.Mat,\n",
    "            P_mat: petsc4py.PETSc.Mat\n",
    "        ) -> None:\n",
    "            \"\"\"Assemble the jacobian.\"\"\"\n",
    "            J_mat.zeroEntries()\n",
    "            dolfinx.fem.petsc.assemble_matrix_block(  # type: ignore[misc]\n",
    "                J_mat, self._J, self._bcs, diagonal=1.0)  # type: ignore[arg-type]\n",
    "            J_mat.assemble()\n",
    "            if self._P is not None:\n",
    "                P_mat.zeroEntries()\n",
    "                dolfinx.fem.petsc.assemble_matrix_block(  # type: ignore[misc]\n",
    "                    P_mat, self._P, self._bcs, diagonal=1.0)  # type: ignore[arg-type]\n",
    "                P_mat.assemble()\n",
    "\n",
    "    # Create problem\n",
    "    problem = NavierStokesProblem(F, J, (u, p), bc)\n",
    "    F_vec = dolfinx.fem.petsc.create_vector_block(problem._F)\n",
    "    J_mat = dolfinx.fem.petsc.create_matrix_block(problem._J)\n",
    "\n",
    "    # Solve\n",
    "    snes = petsc4py.PETSc.SNES().create(mesh.comm)\n",
    "    snes.setTolerances(max_it=20)\n",
    "    snes.getKSP().setType(\"preonly\")\n",
    "    snes.getKSP().getPC().setType(\"lu\")\n",
    "    snes.getKSP().getPC().setFactorSolverType(\"mumps\")\n",
    "    snes.setObjective(problem.obj)\n",
    "    snes.setFunction(problem.F, F_vec)\n",
    "    snes.setJacobian(problem.J, J=J_mat, P=None)\n",
    "    snes.setMonitor(lambda _, it, residual: print(it, residual))\n",
    "    solution = problem.create_snes_solution()\n",
    "    snes.solve(None, solution)\n",
    "    problem.update_solutions(solution)  # TODO can this be safely removed?\n",
    "    solution.destroy()\n",
    "    snes.destroy()\n",
    "    return (u, p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71ab9913",
   "metadata": {},
   "outputs": [],
   "source": [
    "(u_b, p_b) = run_block()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf272d0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_vector_field(u_b, \"u\", glyph_factor=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01d497dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "viskex.dolfinx.plot_scalar_field(p_b, \"p\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b94f7b2f",
   "metadata": {},
   "source": [
    "### Error computation between mixed and block cases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06694128",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_error(\n",
    "    u_m: dolfinx.fem.Function, p_m: dolfinx.fem.Function, u_b: dolfinx.fem.Function, p_b: dolfinx.fem.Function\n",
    ") -> None:\n",
    "    \"\"\"Compute errors between the mixed and block cases.\"\"\"\n",
    "    u_m_norm = np.sqrt(mesh.comm.allreduce(\n",
    "        dolfinx.fem.assemble_scalar(dolfinx.fem.form(ufl.inner(ufl.grad(u_m), ufl.grad(u_m)) * ufl.dx)),\n",
    "        op=mpi4py.MPI.SUM))\n",
    "    err_u_norm = np.sqrt(mesh.comm.allreduce(\n",
    "        dolfinx.fem.assemble_scalar(\n",
    "            dolfinx.fem.form(ufl.inner(ufl.grad(u_b - u_m), ufl.grad(u_b - u_m)) * ufl.dx)),\n",
    "        op=mpi4py.MPI.SUM))\n",
    "    p_m_norm = np.sqrt(mesh.comm.allreduce(\n",
    "        dolfinx.fem.assemble_scalar(dolfinx.fem.form(ufl.inner(p_m, p_m) * ufl.dx)), op=mpi4py.MPI.SUM))\n",
    "    err_p_norm = np.sqrt(mesh.comm.allreduce(\n",
    "        dolfinx.fem.assemble_scalar(dolfinx.fem.form(ufl.inner(p_b - p_m, p_b - p_m) * ufl.dx)),\n",
    "        op=mpi4py.MPI.SUM))\n",
    "    print(\"Relative error for velocity component is equal to\", err_u_norm / u_m_norm)\n",
    "    print(\"Relative error for pressure component is equal to\", err_p_norm / p_m_norm)\n",
    "    assert np.isclose(err_u_norm / u_m_norm, 0., atol=1.e-10)\n",
    "    assert np.isclose(err_p_norm / p_m_norm, 0., atol=1.e-10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "003b5f6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "run_error(u_m, p_m, u_b, p_b)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython"
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
